from typing import Any, Dict, List, Optional, Union
import torch
from torch.nn import Identity
import torch_geometric
from torch_geometric.nn.dense.linear import Linear

import torch.nn.functional as F
from torch import Tensor

class NewMLP(torch_geometric.nn.MLP):
    """Multi-layer perceptron whose normalization slots are always ``Identity``.

    The constructor mirrors the keyword interface of the parent class so it
    can be used as a drop-in replacement, but it builds its own layers: a
    stack of ``Linear`` modules with an activation and dropout between
    consecutive layers, and no activation after the last layer.

    NOTE: ``batch_norm`` and ``batch_norm_kwargs`` are accepted only for
    signature compatibility and are ignored; every entry of ``self.norms`` is
    a ``torch.nn.Identity``.
    """

    def __init__(
        self,
        channel_list: Optional[Union[List[int], int]] = None,
        *,
        in_channels: Optional[int] = None,
        hidden_channels: Optional[int] = None,
        out_channels: Optional[int] = None,
        num_layers: Optional[int] = None,
        dropout: float = 0.,
        act: str = "relu",
        batch_norm: bool = True,
        act_first: bool = False,
        act_kwargs: Optional[Dict[str, Any]] = None,
        batch_norm_kwargs: Optional[Dict[str, Any]] = None,
        bias: bool = True,
        relu_first: bool = False,
    ):
        """Build the linear layers, the activation and the (no-op) norm slots.

        Args:
            channel_list: Layer sizes ``[in, hidden..., out]``. If an ``int``
                is passed it is treated as ``in_channels``.
            in_channels: Input size. When set (directly or via an ``int``
                ``channel_list``), the layer sizes are derived from
                ``in_channels``, ``hidden_channels``, ``out_channels`` and
                ``num_layers``, overriding any list given as ``channel_list``.
            hidden_channels: Size of each of the ``num_layers - 1``
                intermediate layers.
            out_channels: Output size.
            num_layers: Number of linear layers; must be ``>= 1`` whenever
                ``in_channels`` is used.
            dropout: Dropout probability applied before every linear layer
                except the first.
            act: Activation name resolved through ``activation_resolver``.
            batch_norm: Ignored (see class docstring).
            act_first: If true, the activation runs before the norm slot.
            act_kwargs: Extra arguments passed to the activation resolver.
            batch_norm_kwargs: Ignored (see class docstring).
            bias: Whether the linear layers use a bias term.
            relu_first: Alias that is OR-ed into ``act_first``.

        Raises:
            AssertionError: If ``num_layers`` is missing or ``< 1`` while
                ``in_channels`` is used, or if the resulting channel list is
                not a list/tuple with at least two entries.
        """
        # Start the ``__init__`` lookup *after* ``torch_geometric.nn.MLP`` in
        # the MRO so the parent's own layer construction is skipped
        # (presumably this reaches ``torch.nn.Module.__init__`` directly).
        super(torch_geometric.nn.MLP, self).__init__()

        from class_resolver.contrib.torch import activation_resolver

        act_first = act_first or relu_first

        if isinstance(channel_list, int):
            in_channels = channel_list

        if in_channels is not None:
            # Guard against ``None`` explicitly so a forgotten ``num_layers``
            # fails with a readable message instead of an opaque TypeError.
            assert num_layers is not None and num_layers >= 1, (
                "num_layers must be an integer >= 1 when in_channels is given"
            )
            channel_list = (
                [in_channels]
                + [hidden_channels] * (num_layers - 1)
                + [out_channels]
            )

        assert isinstance(channel_list, (tuple, list))
        assert len(channel_list) >= 2
        self.channel_list = channel_list

        self.dropout = dropout
        self.act = activation_resolver.make(act, act_kwargs)
        self.act_first = act_first

        # One linear layer per consecutive pair of channel sizes.
        self.lins = torch.nn.ModuleList()
        for fan_in, fan_out in zip(channel_list[:-1], channel_list[1:]):
            self.lins.append(Linear(fan_in, fan_out, bias=bias))

        # One placeholder norm per hidden layer; always a no-op Identity.
        self.norms = torch.nn.ModuleList(
            [Identity() for _ in channel_list[1:-1]]
        )

        # Not defined in this class; expected to be inherited from the parent.
        self.reset_parameters()

    def forward(self, x: Tensor) -> Tensor:
        """Run ``x`` through the layer stack and return the last layer's output.

        After the first linear layer, every remaining layer is preceded by
        activation and the (Identity) norm slot, in the order selected by
        ``self.act_first``, followed by dropout. No activation is applied to
        the output of the final linear layer.
        """
        x = self.lins[0](x)
        for lin, norm in zip(self.lins[1:], self.norms):
            if self.act_first:
                x = self.act(x)
            x = norm(x)
            if not self.act_first:
                x = self.act(x)
            x = F.dropout(x, p=self.dropout, training=self.training)
            # Call the module rather than ``.forward`` so registered hooks run.
            x = lin(x)
        return x
